use super::mut_void;
use crate::platform::{htons, inet_ntop, inet_pton, ntohs};
use crate::{err, Error, Result};
use core::fmt;
use core::mem::{self, MaybeUninit};
use core::ptr;
use core::slice;
use core::str;

/// Wraps a `sockaddr_in`, `sockaddr_in6` or `sockaddr_un` address inside a
/// `sockaddr_storage`-sized buffer.
/// It can be used with socket send/receive interfaces (e.g. recvfrom/sendto/accept).
#[derive(Clone)]
pub struct SocketAddr {
    // Backing storage for the concrete address type. It may be left (partially)
    // uninitialized, e.g. right after `SocketAddr::uninit()`.
    addr: MaybeUninit<libc::sockaddr_storage>,
}

impl SocketAddr {
    /// 未初始化的变量，可用于任何需要本类型的接口，不会导致内存安全问题
    /// 但非有效地址的内容会导致对应的socket功能接口返回错误.
    pub const fn uninit() -> Self {
        Self {
            addr: MaybeUninit::uninit(),
        }
    }

    /// 初始化为0的变量，可用于任何需要本类型的接口，不会导致内存安全问题
    /// 但非有效地址的内容会导致对应的socket功能接口返回错误.
    pub fn zeroed(family: i32) -> Self {
        let mut addr = MaybeUninit::<libc::sockaddr_storage>::zeroed();
        unsafe { addr.assume_init_mut() }.ss_family = family as u16;
        Self { addr }
    }

    /// 对应sockaddr::sa_family取值
    pub fn family(&self) -> i32 {
        unsafe { self.addr.assume_init_ref() }.ss_family as i32
    }

    /// 根据family返回端口号，如果不是AF_INET/AF_INET6，则返回None
    pub fn port(&self) -> Option<u16> {
        let addr = unsafe { self.addr.assume_init_ref() };
        match addr.ss_family as i32 {
            libc::AF_INET => {
                let addr = unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr_in>() };
                Some(unsafe { ntohs(addr.sin_port) })
            }
            libc::AF_INET6 => {
                let addr = unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr_in6>() };
                Some(unsafe { ntohs(addr.sin6_port) })
            }
            _ => None,
        }
    }

    /// 根据family返回IP地址，如果不是AF_INET/AF_INET6，则返回None
    pub fn ip<'a>(&self, buf: &'a mut [u8]) -> Option<&'a str> {
        let addr = unsafe { self.addr.assume_init_ref() };
        match addr.ss_family as i32 {
            libc::AF_INET => {
                let addr = unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr_in>() };
                let _ = unsafe {
                    inet_ntop(
                        addr.sin_family.into(),
                        (&addr.sin_addr) as *const _ as *const u8,
                        buf.as_mut_ptr(),
                        buf.len() as u32,
                    )
                };
                unsafe { Some(buf_2_str(buf)) }
            }
            libc::AF_INET6 => {
                let addr = unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr_in6>() };
                let _ = unsafe {
                    inet_ntop(
                        addr.sin6_family.into(),
                        (&addr.sin6_addr) as *const _ as *const u8,
                        buf.as_mut_ptr(),
                        buf.len() as u32,
                    )
                };
                unsafe { Some(buf_2_str(buf)) }
            }
            _ => None,
        }
    }

    /// 根据family返回path，如果不是AF_UNIX，则返回None
    pub fn path(&self) -> Option<&str> {
        let addr = unsafe { self.addr.assume_init_ref() };
        if addr.ss_family as i32 == libc::AF_UNIX {
            let addr = unsafe { &*(addr as *const _ as *const libc::sockaddr_un) };
            let path = unsafe { slice::from_raw_parts(addr.sun_path.as_ptr().cast::<u8>(), addr.sun_path.len()) };
            unsafe { Some(buf_2_str(path)) }
        } else {
            None
        }
    }

    /// 构建AF_UNIX的地址
    pub fn unix(path: &str) -> Result<Self> {
        let mut un = libc::sockaddr_un {
            sun_family: libc::AF_UNIX as u16,
            sun_path: [0; 108],
        };
        if path.len() > 108 {
            return Err(Error::new(err::EINVAL));
        }
        unsafe {
            ptr::copy_nonoverlapping(
                path.as_ptr().cast::<libc::c_char>(),
                un.sun_path.as_mut_ptr(),
                path.len(),
            );
        }
        Ok(Self::new(un))
    }

    /// 构建AF_INET的地址
    pub fn inet(ip: &str, port: u16) -> Result<Self> {
        let mut inet = unsafe { MaybeUninit::<libc::sockaddr_in>::zeroed().assume_init_read() };
        inet.sin_family = libc::AF_INET as u16;
        inet.sin_port = unsafe { htons(port) };
        unsafe {
            Self::inet_addr(libc::AF_INET, ip, mut_void(&mut inet.sin_addr))?;
        }

        Ok(Self::new(inet))
    }

    /// 构建AF_INET6的地址
    pub fn inet6(ip: &str, port: u16) -> Result<Self> {
        let mut inet = unsafe { MaybeUninit::<libc::sockaddr_in6>::zeroed().assume_init_read() };
        inet.sin6_family = libc::AF_INET6 as u16;
        inet.sin6_port = unsafe { htons(port) };
        unsafe {
            Self::inet_addr(libc::AF_INET6, ip, mut_void(&mut inet.sin6_addr))?;
        }

        Ok(Self::new(inet))
    }

    /// 从一个Ip地址和端口号组成的字符串中构建AF_INET/AF_INET6的地址
    /// inet - ipv4:port
    /// inet6 - [ipv6]:port
    /// inet - port, 等同于"0.0.0.0":port
    pub fn inet_from(addr: &str) -> Result<Self> {
        let mut it = addr.rsplitn(2, ':');
        let (port, ip) = match (it.next(), it.next()) {
            (Some(port), Some(ip)) => (port, ip),
            (Some(port), None) => (port, "0.0.0.0"),
            _ => return Err(err::EINVAL.into()),
        };

        let Ok(port) = port.parse::<u16>() else {
            return Err(err::EINVAL.into());
        };

        if !ip.is_empty() && ip.as_bytes()[0_usize] == b'[' {
            Self::inet6(&ip[1..ip.len() - 1], port)
        } else {
            Self::inet(ip, port)
        }
    }

    /// 返回的地址可用于同libc接口的交互.
    pub fn get(&self) -> (&libc::sockaddr, libc::socklen_t) {
        let addr = unsafe { self.addr.assume_init_ref() };
        match addr.ss_family as i32 {
            libc::AF_INET => (
                unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr>() },
                mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
            ),
            libc::AF_INET6 => (
                unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr>() },
                mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
            ),
            libc::AF_UNIX => (
                unsafe { &*self.addr.as_ptr().cast::<libc::sockaddr>() },
                mem::size_of::<libc::sockaddr_un>() as libc::socklen_t,
            ),
            _ => panic!("only support AF_INET/AF_INET6/AF_UNIXDOMAIN"),
        }
    }

    /// 返回的地址可用于同libc接口的交互.
    pub fn get_mut(&mut self) -> (&mut libc::sockaddr, libc::socklen_t) {
        let addr = unsafe { self.addr.assume_init_ref() };
        match addr.ss_family as i32 {
            libc::AF_INET => (
                unsafe { &mut *self.addr.as_mut_ptr().cast::<libc::sockaddr>() },
                mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
            ),
            libc::AF_INET6 => (
                unsafe { &mut *self.addr.as_mut_ptr().cast::<libc::sockaddr>() },
                mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
            ),
            libc::AF_UNIX => (
                unsafe { &mut *self.addr.as_mut_ptr().cast::<libc::sockaddr>() },
                mem::size_of::<libc::sockaddr_un>() as libc::socklen_t,
            ),
            _ => panic!("only support AF_INET/AF_INET6/AF_UNIXDOMAIN"),
        }
    }

    /// 返回的地址可用于同libc接口的交互.
    pub fn get_uninit_mut(&mut self) -> (*mut libc::sockaddr, libc::socklen_t) {
        (
            self.addr.as_mut_ptr().cast::<libc::sockaddr>(),
            mem::size_of_val(&self.addr) as libc::socklen_t,
        )
    }

    fn new<T>(val: T) -> Self {
        let mut addr = MaybeUninit::<libc::sockaddr_storage>::uninit();
        unsafe {
            addr.as_mut_ptr().cast::<T>().write(val);
        }
        Self { addr }
    }

    unsafe fn inet_addr(family: i32, src: &str, dst: *mut libc::c_void) -> Result<()> {
        let mut name = [0_u8; 128];
        if src.len() >= 128 {
            return Err(Error::new(err::EINVAL));
        }
        ptr::copy_nonoverlapping(src.as_ptr(), name.as_mut_ptr(), src.len());

        let ret = inet_pton(family, name.as_ptr(), dst.cast::<u8>());
        if ret == 1 {
            Ok(())
        } else {
            Err(Error::last())
        }
    }
}

/// Formats the address as:
/// - `AF_INET`:  `<ip>:<port>`
/// - `AF_INET6`: `[<ip>]:<port>`
/// - `AF_UNIX`:  `<path>`
/// - any other family: `unknown family: <number>`
impl fmt::Debug for SocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let family = self.family();
        match family {
            libc::AF_UNIX => write!(f, "{}", self.path().unwrap()),
            libc::AF_INET | libc::AF_INET6 => {
                // 48 bytes covers INET6_ADDRSTRLEN (46), so it is ample for IPv4 text too.
                let mut text = [0_u8; 48];
                let host = self.ip(&mut text).unwrap();
                let port = self.port().unwrap();
                if family == libc::AF_INET6 {
                    write!(f, "[{host}]:{port}")
                } else {
                    write!(f, "{host}:{port}")
                }
            }
            other => write!(f, "unknown family: {other}"),
        }
    }
}

/// Produces exactly the same text as the `Debug` representation:
/// - `AF_INET`:  `<ip>:<port>`
/// - `AF_INET6`: `[<ip>]:<port>`
/// - `AF_UNIX`:  `<path>`
impl fmt::Display for SocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self)
    }
}

/// Interprets `buf` as a NUL-terminated C string and returns its text.
///
/// The text ends at the first NUL byte, or at the end of `buf` if there is none. Bytes
/// that are not valid UTF-8 are never exposed: if the text is not valid UTF-8, only its
/// longest valid UTF-8 prefix is returned.
///
/// # Safety
/// No precondition: this function is sound for any input. It keeps its `unsafe`
/// signature only so that existing `unsafe { .. }` call sites stay unchanged.
unsafe fn buf_2_str(buf: &[u8]) -> &str {
    let len = buf.iter().position(|&c| c == 0).unwrap_or(buf.len());
    let bytes = &buf[..len];
    match str::from_utf8(bytes) {
        Ok(text) => text,
        // Buffers filled by the kernel (e.g. a unix socket path) may hold arbitrary bytes;
        // fall back to the part that is valid UTF-8 instead of creating an invalid `&str`.
        Err(e) => str::from_utf8(&bytes[..e.valid_up_to()]).unwrap_or_default(),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    extern crate std;
    use std::format;

    #[test]
    fn test_inet_addr() {
        // IPv4 host with an explicit port.
        let v4 = SocketAddr::inet_from("127.0.0.1:200").unwrap();
        assert_eq!(format!("{}", v4), "127.0.0.1:200");
        assert_eq!(v4.port(), Some(200_u16));
        let mut ip_buf = [0_u8; 16];
        assert_eq!(v4.ip(&mut ip_buf), Some("127.0.0.1"));
        assert_eq!(v4.path(), None);

        // Bracketed IPv6 host.
        let v6 = SocketAddr::inet_from("[::99]:200").unwrap();
        assert_eq!(format!("{}", v6), "[::99]:200");
        assert_eq!(v6.port(), Some(200_u16));

        // A missing or empty port is rejected.
        assert!(SocketAddr::inet_from("127.0.0.1").is_err());
        assert!(SocketAddr::inet_from("127.0.0.1:").is_err());

        // A bare port means the IPv4 wildcard address.
        let wildcard = SocketAddr::inet_from("200").unwrap();
        assert_eq!(wildcard.ip(&mut ip_buf), Some("0.0.0.0"));
        assert_eq!(wildcard.port(), Some(200_u16));

        // An empty host is not a valid IPv4 address.
        assert!(SocketAddr::inet_from(":200").is_err());
    }

    #[test]
    fn test_unix_addr() {
        let unix = SocketAddr::unix("/x/y/z").unwrap();
        assert_eq!(format!("{}", unix), "/x/y/z");
        // Inet-only accessors report nothing for AF_UNIX, whatever the buffer size.
        assert_eq!(unix.port(), None);
        let mut tiny = [0_u8; 1];
        assert_eq!(unix.ip(&mut tiny), None);
        assert_eq!(unix.path(), Some("/x/y/z"));
    }
}
